{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b4f81c4f-fbf7-4495-9214-106dc5a97e9a",
   "metadata": {},
   "source": [
    "# PRefLexOR Inference: Thinking and Reflection and Agentic Reasoning "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56e8c3ef-85ce-4727-8ebc-cd33b852ec0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install git+https://github.com/lamm-mit/PRefLexOR.git --quiet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "45b15ada-6e82-423a-80e2-3267b99890c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from tqdm.notebook import tqdm\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer \n",
    "\n",
    "import torch\n",
    "\n",
    "from PRefLexOR import *\n",
    "\n",
    "# Define thinking and reflection tokens\n",
    "think_start = '<|thinking|>'\n",
    "think_end = '<|/thinking|>'\n",
    "reflect_start=\"<|reflect|>\" \n",
    "reflect_end= \"<|/reflect|>\"\n",
    "\n",
    "import transformers"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6900773b-d722-44f6-9d69-e11a30d7d623",
   "metadata": {},
   "source": [
    "### Load model "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e821e723-05d6-408c-a880-5b7697c6968a",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name='lamm-mit/PRefLexOR_ORPO_DPO_EXO_REFLECT_10222024'\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(model_name,     \n",
    "    torch_dtype =torch.bfloat16,\n",
    "    #attn_implementation=\"flash_attention_2\",\n",
    "    device_map=\"auto\", trust_remote_code=True,\n",
    "    )\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name, )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e1ff7d8c-f58d-4bdd-98a0-39b6c1bb6f2e",
   "metadata": {},
   "source": [
    "### Inference: Conventional"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34a7c329-7f34-445e-9d19-aca68d758d47",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "txt = 'What is the relationship between materials and music? Brief answer.' + f' Use {think_start}.'\n",
    "\n",
    "output_text, messages = generate_local_model(\n",
    "    model=model, \n",
    "    tokenizer=tokenizer, \n",
    "    prompt=txt, \n",
    "    system_prompt='',  \n",
    "    num_return_sequences=1, \n",
    "    repetition_penalty=1.0, \n",
    "    temperature=0.1, \n",
    "    max_new_tokens=2024, \n",
    "    messages=[], \n",
    "    do_sample=True\n",
    ")\n",
    "\n",
    "print(output_text)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5d9648d6-8e0f-4182-91cf-2ecf79829865",
   "metadata": {},
   "source": [
    "#### Extract thinking or other sections from the output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e734003-b88d-4492-a701-373645658602",
   "metadata": {},
   "outputs": [],
   "source": [
    "thinking    = extract_text(output_text, thinking_start=think_start, thinking_end=think_end)[0].strip()\n",
    "reflection  = extract_text(output_text, thinking_start=reflect_start, thinking_end=reflect_end)[0].strip()\n",
    "answer_only = extract_text(output_text, thinking_start=reflect_end, thinking_end=\"NONE\").strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7f802cf-28ae-4998-a351-2cb0279bf7ce",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "print (\"THINKING:\\n\\n\", thinking)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d66a217-0fc8-43ef-9cdc-9171424e0fde",
   "metadata": {},
   "outputs": [],
   "source": [
    "print (\"REFLECTION:\\n\\n\", thinking)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41df65d5-2fb0-45b2-8f65-0f01a7832dc9",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "print (\"ANSWER:\\n\\n\", answer_only)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5c935961-de64-4914-adf0-11c003d9a25e",
   "metadata": {},
   "source": [
    "### Inference: Recursive using multi-agent reasoning using thinking and reflection tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31598fd1-bba5-4a4e-99e4-a0161ddf8942",
   "metadata": {},
   "outputs": [],
   "source": [
    "from PRefLexOR import recursive_response_from_thinking"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "01b87fa6-cb89-4690-b639-b50073d4237b",
   "metadata": {},
   "source": [
    "#### Load second model that will be the critic agent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e6d2430-2a8b-4f31-b21a-d2b3b5d8d8fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name_base = \"meta-llama/Llama-3.2-3B-Instruct\"\n",
    "\n",
    "critic_model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_name_base, \n",
    "    torch_dtype=torch.bfloat16, \n",
    "    #attn_implementation=\"flash_attention_2\", \n",
    "    device_map=\"auto\", \n",
    "    trust_remote_code=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3605ae7-d57f-4678-a728-2bb3acf541ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "output_text, output_list, output_text_integrated = recursive_response(\n",
    "    model=model,\n",
    "    tokenizer=tokenizer,\n",
    "    model_critic=critic_model,\n",
    "    tokenizer_critic=tokenizer,  #same tokenizer for critic agent as the reasoning model in our case\n",
    "    question='How do biological materials fail gracefully? Brief answer.',\n",
    "    N=2,\n",
    "    temperature=0.1,\n",
    "    temperature_improvement=0.1,\n",
    "    system_prompt='You are a helpful assistant.',\n",
    "    system_prompt_critic='You carefully improve responses, with attention to detail, and following all directions.',\n",
    "    verbatim=False,\n",
    "    thinking_start=think_start, thinking_end = think_end, \n",
    "    reflect_start=reflect_start, reflect_end= reflect_end,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c94b21f-cd26-49c6-b603-5b4c58cf0852",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "for i, item in enumerate(output_list):\n",
    "    answer_only=extract_text(item, thinking_start=reflect_end, thinking_end=\"NONE\")\n",
    "    print (64*\"-\"+f\"\\n>>>i={i}<<<\\n\"+64*\"-\")\n",
    "    print (answer_only)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88297285-4172-4788-91b8-9f6c526f9a9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "print (\"INTEGRATED RESPONSE:\")\n",
    "print (output_text_integrated)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
